Implement sharding and device mesh debug tool #2328
Conversation
**Why this PR**

Understanding different DTensor shardings is important, especially when working on a full-DTensor project. This tool is created for debugging purposes.

**What this PR does**

This PR adds a sharding debug tool that captures and visualizes DTensor sharding information during training. When enabled via `debug.log_sharding_info=True`, it registers forward and backward hooks on all modules to record tensor placements, device mesh info, and shapes for one forward/backward pass. The tool outputs both a formatted ASCII text file and an interactive HTML visualization.

**Limitation:** This tool can only track 1) module inputs/outputs and their gradients, and 2) module states and their gradients. Activations generated by ops that are not modules cannot be tracked; that would require TorchFunctionMode or TorchDispatchMode.

**For Reviewers:** The UX functions (ASCII and HTML) are completely generated by Claude. I'm not an experienced frontend developer and didn't closely code-review those files.

```
NGPU=8 COMM_MODE=fake_backend CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --parallelism.tensor_parallel_degree=8 --debug.log_sharding_info
```

ghstack-source-id: 5354c4e
Pull-Request: #2328
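The hook-based capture described above can be sketched as a minimal, pure-Python mock with no torch dependency. `TensorInfo`, `ModuleRecord`, and `ShardingRecorder` are illustrative names for this sketch, not the tool's actual classes; the real tool would fill these records from inside `register_forward_hook`/`register_full_backward_hook` callbacks:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for DTensor metadata (placements/mesh are shown
# as plain strings here; the real tool would read them off a DTensor).
@dataclass
class TensorInfo:
    shape: tuple
    placements: tuple  # e.g. ("Shard(0)",) or ("Replicate()",)
    mesh: str          # e.g. "DeviceMesh((8,), ('tp',))"

@dataclass
class ModuleRecord:
    fqn: str
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)

class ShardingRecorder:
    """Collects one record per module FQN for a single forward pass."""

    def __init__(self):
        self.records = {}

    def record(self, fqn, inputs, outputs):
        # Called from a (hypothetical) forward hook with the module's FQN.
        rec = self.records.setdefault(fqn, ModuleRecord(fqn))
        rec.inputs.extend(inputs)
        rec.outputs.extend(outputs)

    def to_ascii(self):
        # A bare-bones version of the formatted ASCII output.
        lines = []
        for rec in self.records.values():
            lines.append(f"module {rec.fqn}")
            for t in rec.inputs:
                lines.append(f"  in : shape={t.shape} placements={t.placements}")
            for t in rec.outputs:
                lines.append(f"  out: shape={t.shape} placements={t.placements}")
        return "\n".join(lines)

recorder = ShardingRecorder()
recorder.record(
    "layers.0.attention.wq",
    [TensorInfo((8, 2048, 4096), ("Replicate()",), "mesh(tp=8)")],
    [TensorInfo((8, 2048, 512), ("Shard(2)",), "mesh(tp=8)")],
)
print(recorder.to_ascii())
```

The key design point is that the recorder stores plain data keyed by module FQN, so renderers (ASCII, HTML) only need to consume `records`.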
wwwjn
left a comment
This is generally useful for all DTensor developers, not only titan developers (and we would expect infra developers to be interested in this tool, rather than model researchers). Should we put it in PyTorch, like flight_recorder?
Most of our users also care about scaling, so this is useful for them too.
The reason I haven't put it in PyTorch is that this tool is not general enough yet. It will likely be polished continuously once we start using it. After it is mature enough, we can upstream it to PyTorch.
wconstab
left a comment
I'm wondering about the best way to land this. It seems nice, but it's also a big pile of code.
- it's not obvious that it is torchtitan-specific. Should some or all of this go into torch itself?
Is there a nice way to decouple a 'core' piece that lands in torch? Perhaps:
- a context manager that itself produces a well-defined data structure
- a clearly separate plugin that takes the data and renders it
This way, you can keep the HTML stuff out of tree.
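The decoupling suggested above could look something like the following sketch: a core context manager that yields a plain, renderer-agnostic data structure, and a separate renderer that consumes it. All names here (`capture_sharding`, `render_ascii`) are illustrative, not an existing API:

```python
from contextlib import contextmanager

@contextmanager
def capture_sharding():
    # Core piece that could live in torch: yields a well-defined data
    # structure. The real implementation would install forward/backward
    # hooks here and remove them on exit; the sketch just yields a list.
    records = []
    try:
        yield records
    finally:
        pass  # hook removal would go here

def render_ascii(records):
    """Out-of-tree plugin: turns records into text. An HTML renderer
    could be a sibling plugin honoring the same input contract."""
    return "\n".join(f"{fqn}: {placement}" for fqn, placement in records)

with capture_sharding() as records:
    # Stand-in for what a hook would append during a forward pass.
    records.append(("tok_embeddings.weight", "Shard(0)"))

print(render_ascii(records))
```

With this split, only the data contract needs to be stable; visualization code can iterate independently outside the torch tree.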
I think it partially addresses the pain point of sharding info not being an integral part of nn.Module, in the sense that we explicitly insert hooks to record what happens, somewhat similar to
Can we merge/iterate with one of those tools in PyTorch?
@tianyu-l cc @wconstab @wwwjn The same answer applies to your questions as well.
I think we could put it in a separate branch for now. Also, using HTML might be too burdensome since most of the info is structured; can we simply put it in a JSON file? That would also simplify the landing process.
@wwwjn Let's put it in another branch for now. But I'm not sure I agree on the HTML part. HTML is more readable; even though this PR also generates an ASCII file, I merely use that for programmatic verification, and mostly I read the HTML. tlparse also uses HTML. Let's iterate in another branch to understand the best way to upstream the tool.
@fegin I'm unsure how much work it would be to integrate into CommDebugMode. As you can see, CommDebugMode already includes some of the information you're trying to output. Furthermore, CommDebugMode uses noise levels to control how much information is output. In this case, you could either create a separate output function just for your output, or make your work the minimum noise level, where no ops are shown and it's just module sharding information. In addition, there technically was an HTML
@anshul-si Yes, I saw that before :) But the main information I need is simply the input, state, and output sharding of a module, and I only see state sharding there. I guess one could try to infer it from the aten op order, but it is better to explicitly capture this information. I think one way forward is to enhance CommDebugMode to capture


Stack from ghstack (oldest at bottom):
Why this PR
Understanding different DTensor sharding is import especially when doing full dtensor project. Creating this tool for debugging purpose.
What this PR does
This PR adds a sharding debug tool that captures and visualizes DTensor sharding information during training. When enabled via
debug.log_sharding_info=True, it registers forward and backward hooks on all modules to record tensor placements, device mesh info, and shapes for one forward/backward pass. The tool outputs both a formatted ASCII text file and an interactive HTML visualization.Limitation:
This tool can only track 1) module inputs/outputs and the gradients and 2) module states and the gradients. Any activation that generate by ops that is not a module can not be tracked. We will have to use TorchFunctionMode or TorchDispatchMode to do this.
For Reviewers:
UX functions (ASCII and html) are completely generated by Claude. I'm not an experienced frontend developer and didn't code review too much html file.